{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sklearn-compatible Grid Search for classification\n",
    "\n",
    "Grid search is an in-processing technique that can be used for fair classification or fair regression. For classification, it reduces fair classification to a sequence of cost-sensitive classification problems and returns, among the candidates searched, the deterministic classifier with the lowest empirical error subject to the fair classification constraints. The code for grid search wraps the source class `fairlearn.reductions.GridSearch` available in the https://github.com/fairlearn/fairlearn library, licensed under the MIT License, Copyright Microsoft Corporation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import GridSearchCV, train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from aif360.sklearn.inprocessing import GridSearchReduction\n",
    "\n",
    "from aif360.sklearn.datasets import fetch_adult\n",
    "from aif360.sklearn.metrics import disparate_impact_ratio, average_odds_error, generalized_fpr\n",
    "from aif360.sklearn.metrics import generalized_fnr, difference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains the protected attribute values for each sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. This layout keeps aif360 compatible with scikit-learn objects.\n",
    "\n",
    "For example, we can easily load the Adult dataset from UCI with the following line:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <th>Non-white</th>\n",
       "      <th>Male</th>\n",
       "      <td>25.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>11th</td>\n",
       "      <td>7.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Non-white</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <th>White</th>\n",
       "      <th>Male</th>\n",
       "      <td>38.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <th>White</th>\n",
       "      <th>Male</th>\n",
       "      <td>28.0</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <th>Non-white</th>\n",
       "      <th>Male</th>\n",
       "      <td>44.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Non-white</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <th>White</th>\n",
       "      <th>Male</th>\n",
       "      <td>34.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>10th</td>\n",
       "      <td>6.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   age  workclass     education  education-num  \\\n",
       "  race      sex                                                  \n",
       "0 Non-white Male  25.0    Private          11th            7.0   \n",
       "1 White     Male  38.0    Private       HS-grad            9.0   \n",
       "2 White     Male  28.0  Local-gov    Assoc-acdm           12.0   \n",
       "3 Non-white Male  44.0    Private  Some-college           10.0   \n",
       "5 White     Male  34.0    Private          10th            6.0   \n",
       "\n",
       "                      marital-status         occupation   relationship  \\\n",
       "  race      sex                                                          \n",
       "0 Non-white Male       Never-married  Machine-op-inspct      Own-child   \n",
       "1 White     Male  Married-civ-spouse    Farming-fishing        Husband   \n",
       "2 White     Male  Married-civ-spouse    Protective-serv        Husband   \n",
       "3 Non-white Male  Married-civ-spouse  Machine-op-inspct        Husband   \n",
       "5 White     Male       Never-married      Other-service  Not-in-family   \n",
       "\n",
       "                       race   sex  capital-gain  capital-loss  hours-per-week  \\\n",
       "  race      sex                                                                 \n",
       "0 Non-white Male  Non-white  Male           0.0           0.0            40.0   \n",
       "1 White     Male      White  Male           0.0           0.0            50.0   \n",
       "2 White     Male      White  Male           0.0           0.0            40.0   \n",
       "3 Non-white Male  Non-white  Male        7688.0           0.0            40.0   \n",
       "5 White     Male      White  Male           0.0           0.0            30.0   \n",
       "\n",
       "                 native-country  \n",
       "  race      sex                  \n",
       "0 Non-white Male  United-States  \n",
       "1 White     Male  United-States  \n",
       "2 White     Male  United-States  \n",
       "3 Non-white Male  United-States  \n",
       "5 White     Male  United-States  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y, sample_weight = fetch_adult()\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# there is one unused category ('Never-worked') that was dropped during dropna;\n",
    "# reassign the result because newer pandas versions no longer accept inplace=True here\n",
    "X['workclass'] = X['workclass'].cat.remove_unused_categories()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then map the protected attributes to integers,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X.index = pd.MultiIndex.from_arrays(X.index.codes, names=X.index.names)\n",
    "y.index = pd.MultiIndex.from_arrays(y.index.codes, names=y.index.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and the target classes to 0/1,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = pd.Series(y.factorize(sort=True)[0], index=y.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and split the dataset into training (70%) and test (30%) sets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_train, X_test,\n",
    " y_train, y_test) = train_test_split(X, y, train_size=0.7, random_state=1234567)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We one-hot encode with pandas so that the columns associated with protected attributes are easy to reference by name, which grid search reduction requires."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>workclass_Federal-gov</th>\n",
       "      <th>workclass_Local-gov</th>\n",
       "      <th>workclass_Private</th>\n",
       "      <th>workclass_Self-emp-inc</th>\n",
       "      <th>workclass_Self-emp-not-inc</th>\n",
       "      <th>...</th>\n",
       "      <th>native-country_Portugal</th>\n",
       "      <th>native-country_Puerto-Rico</th>\n",
       "      <th>native-country_Scotland</th>\n",
       "      <th>native-country_South</th>\n",
       "      <th>native-country_Taiwan</th>\n",
       "      <th>native-country_Thailand</th>\n",
       "      <th>native-country_Trinadad&amp;Tobago</th>\n",
       "      <th>native-country_United-States</th>\n",
       "      <th>native-country_Vietnam</th>\n",
       "      <th>native-country_Yugoslavia</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>30149</th>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>58.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>42.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12028</th>\n",
       "      <th>1</th>\n",
       "      <th>0</th>\n",
       "      <td>51.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36374</th>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>26.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1887.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8055</th>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>44.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38108</th>\n",
       "      <th>1</th>\n",
       "      <th>1</th>\n",
       "      <td>33.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 100 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                 age  education-num  capital-gain  capital-loss  \\\n",
       "      race sex                                                    \n",
       "30149 1    1    58.0           11.0           0.0           0.0   \n",
       "12028 1    0    51.0           12.0           0.0           0.0   \n",
       "36374 1    1    26.0           14.0           0.0        1887.0   \n",
       "8055  1    1    44.0            3.0           0.0           0.0   \n",
       "38108 1    1    33.0            6.0           0.0           0.0   \n",
       "\n",
       "                hours-per-week  workclass_Federal-gov  workclass_Local-gov  \\\n",
       "      race sex                                                               \n",
       "30149 1    1              42.0                      0                    0   \n",
       "12028 1    0              30.0                      0                    0   \n",
       "36374 1    1              40.0                      0                    0   \n",
       "8055  1    1              40.0                      0                    0   \n",
       "38108 1    1              40.0                      0                    0   \n",
       "\n",
       "                workclass_Private  workclass_Self-emp-inc  \\\n",
       "      race sex                                              \n",
       "30149 1    1                    0                       0   \n",
       "12028 1    0                    0                       0   \n",
       "36374 1    1                    1                       0   \n",
       "8055  1    1                    1                       0   \n",
       "38108 1    1                    1                       0   \n",
       "\n",
       "                workclass_Self-emp-not-inc  ...  native-country_Portugal  \\\n",
       "      race sex                              ...                            \n",
       "30149 1    1                             1  ...                        0   \n",
       "12028 1    0                             1  ...                        0   \n",
       "36374 1    1                             0  ...                        0   \n",
       "8055  1    1                             0  ...                        0   \n",
       "38108 1    1                             0  ...                        0   \n",
       "\n",
       "                native-country_Puerto-Rico  native-country_Scotland  \\\n",
       "      race sex                                                        \n",
       "30149 1    1                             0                        0   \n",
       "12028 1    0                             0                        0   \n",
       "36374 1    1                             0                        0   \n",
       "8055  1    1                             1                        0   \n",
       "38108 1    1                             0                        0   \n",
       "\n",
       "                native-country_South  native-country_Taiwan  \\\n",
       "      race sex                                                \n",
       "30149 1    1                       0                      0   \n",
       "12028 1    0                       0                      0   \n",
       "36374 1    1                       0                      0   \n",
       "8055  1    1                       0                      0   \n",
       "38108 1    1                       0                      0   \n",
       "\n",
       "                native-country_Thailand  native-country_Trinadad&Tobago  \\\n",
       "      race sex                                                            \n",
       "30149 1    1                          0                               0   \n",
       "12028 1    0                          0                               0   \n",
       "36374 1    1                          0                               0   \n",
       "8055  1    1                          0                               0   \n",
       "38108 1    1                          0                               0   \n",
       "\n",
       "                native-country_United-States  native-country_Vietnam  \\\n",
       "      race sex                                                         \n",
       "30149 1    1                               1                       0   \n",
       "12028 1    0                               0                       0   \n",
       "36374 1    1                               1                       0   \n",
       "8055  1    1                               0                       0   \n",
       "38108 1    1                               1                       0   \n",
       "\n",
       "                native-country_Yugoslavia  \n",
       "      race sex                             \n",
       "30149 1    1                            0  \n",
       "12028 1    0                            0  \n",
       "36374 1    1                            0  \n",
       "8055  1    1                            0  \n",
       "38108 1    1                            0  \n",
       "\n",
       "[5 rows x 100 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train, X_test = pd.get_dummies(X_train), pd.get_dummies(X_test)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The protected attribute information is also replicated in the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "       race  sex\n",
       "30149  1     1      0\n",
       "12028  1     0      1\n",
       "36374  1     1      1\n",
       "8055   1     1      0\n",
       "38108  1     1      0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8373258642293802\n"
     ]
    }
   ],
   "source": [
    "y_pred = LogisticRegression(solver='lbfgs').fit(X_train, y_train).predict(X_test)\n",
    "lr_acc = accuracy_score(y_test, y_pred)\n",
    "print(lr_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can assess how close the predictions are to equality of odds.\n",
    "\n",
    "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n",
    "\n",
    "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.10043769764182503\n"
     ]
    }
   ],
   "source": [
    "lr_aoe = average_odds_error(y_test, y_pred, prot_attr='sex')\n",
    "print(lr_aoe)"
   ]
  },
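  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula above concrete, the sketch below recomputes the same quantity by hand with NumPy on a small toy example. The helper `manual_average_odds_error` and the `toy_*` arrays are hypothetical names introduced for illustration; they are not part of aif360, and the library's implementation may differ in details such as sample weighting. In the toy data the TPR is 0.5 for the unprivileged group and 1.0 for the privileged group, and the FPR is 0.0 and 0.5 respectively, so both gaps are 0.5 and so is their average."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def manual_average_odds_error(y_true, y_pred, group):\n",
    "    # Illustrative helper (not an aif360 function).\n",
    "    # group: 1 marks the privileged group, 0 the unprivileged group.\n",
    "    y_true, y_pred, group = (np.asarray(a) for a in (y_true, y_pred, group))\n",
    "\n",
    "    def rates(mask):\n",
    "        # TPR: share of actual positives in the group that are predicted positive\n",
    "        tpr = y_pred[mask & (y_true == 1)].mean()\n",
    "        # FPR: share of actual negatives in the group that are predicted positive\n",
    "        fpr = y_pred[mask & (y_true == 0)].mean()\n",
    "        return tpr, fpr\n",
    "\n",
    "    tpr_u, fpr_u = rates(group == 0)\n",
    "    tpr_p, fpr_p = rates(group == 1)\n",
    "    return 0.5 * (abs(fpr_u - fpr_p) + abs(tpr_u - tpr_p))\n",
    "\n",
    "# Toy data: four unprivileged samples followed by four privileged samples\n",
    "toy_true = [1, 1, 0, 0, 1, 1, 0, 0]\n",
    "toy_pred = [1, 0, 0, 0, 1, 1, 1, 0]\n",
    "toy_group = [0, 0, 0, 0, 1, 1, 1, 1]\n",
    "print(manual_average_odds_error(toy_true, toy_pred, toy_group))"
   ]
  },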
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Grid Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a base model for the candidate classifiers. The base model's `fit` method must accept a `sample_weight` argument. For details, refer to the docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = LogisticRegression(solver='lbfgs')"
   ]
  },
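  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check the sample-weight requirement before running the search is scikit-learn's `has_fit_parameter` utility. This is a minimal sketch; the `candidates` dictionary is a hypothetical name used only for illustration, and `KNeighborsClassifier` is included as an example of an estimator whose `fit` does not take sample weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.utils.validation import has_fit_parameter\n",
    "\n",
    "# Hypothetical candidates: only estimators whose fit() accepts sample_weight qualify\n",
    "candidates = {\n",
    "    'logistic_regression': LogisticRegression(solver='lbfgs'),\n",
    "    'k_neighbors': KNeighborsClassifier(),\n",
    "}\n",
    "for name, model in candidates.items():\n",
    "    print(name, has_fit_parameter(model, 'sample_weight'))"
   ]
  },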
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Determine the columns associated with the protected attribute(s). Grid search can handle more than one attribute, but it is computationally expensive. A similar method with less computational overhead is exponentiated gradient reduction, detailed at [examples/sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb](sklearn/demo_exponentiated_gradient_reduction_sklearn.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "prot_attr_cols = [colname for colname in X_train if \"sex\" in colname]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Search for the best classifier and observe test accuracy. Other options for `constraints` include \"DemographicParity\", \"TruePositiveRateDifference\", and \"ErrorRateRatio\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8318714527898577\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)  # needed for reproducibility\n",
    "grid_search_red = GridSearchReduction(prot_attr=prot_attr_cols,\n",
    "                                      estimator=estimator,\n",
    "                                      constraints=\"EqualizedOdds\",\n",
    "                                      grid_size=20,\n",
    "                                      drop_prot_attr=False)\n",
    "grid_search_red.fit(X_train, y_train)\n",
    "gs_acc = grid_search_red.score(X_test, y_test)\n",
    "print(gs_acc)\n",
    "\n",
    "# Check that accuracy stays within 0.03 of the logistic regression baseline\n",
    "assert abs(lr_acc - gs_acc) < 0.03"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0551512399603683\n"
     ]
    }
   ],
   "source": [
    "gs_aoe = average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')\n",
    "print(gs_aoe)\n",
    "\n",
    "# Check that the average odds error improved over the baseline\n",
    "assert gs_aoe < lr_aoe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of passing a string for `constraints`, we can also pass a `fairlearn.reductions.Moment` object. You can use a predefined moment, as we do below, or create a custom moment using the fairlearn library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8318714527898577"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import fairlearn.reductions as red\n",
    "\n",
    "np.random.seed(0)  # needed for reproducibility\n",
    "grid_search_red = GridSearchReduction(prot_attr=prot_attr_cols,\n",
    "                                      estimator=estimator,\n",
    "                                      constraints=red.EqualizedOdds(),\n",
    "                                      grid_size=20,\n",
    "                                      drop_prot_attr=False)\n",
    "grid_search_red.fit(X_train, y_train)\n",
    "grid_search_red.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0551512399603683"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_odds_error(y_test, grid_search_red.predict(X_test), prot_attr='sex')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}